"""Estimators of the maximum mean discrepancy (MMD) and its witness function, using an RBF kernel."""
import numpy as np 
from sklearn.gaussian_process.kernels import RBF
from scipy.spatial import distance

def mmd_u(x, y, gamma=None):
    '''Unbiased estimator of the squared MMD between samples x and y.

    With k an RBF kernel of length scale ``gamma``, computes

          1/(n(n-1)) * sum_{i != j} k(x_i, x_j)
        + 1/(m(m-1)) * sum_{i != j} k(y_i, y_j)
        - 2/(n m)    * sum_{i, j}   k(x_i, y_j)

    Args:
        x: array of shape (n,) or (n, d), samples from the first distribution.
        y: array of shape (m,) or (m, d), samples from the second distribution.
        gamma: scalar RBF length scale; if None, the median heuristic
            ``kernelwidthPair(x, y)`` is used.

    Returns:
        The estimate as a float. Because the within-sample sums exclude the
        diagonal, the value can be negative.

    Raises:
        ValueError: if x or y has fewer than two samples (the normalisers
            n(n-1) and m(m-1) would be zero).
    '''
    # Treat 1-D inputs as samples of a single feature. Check each array on
    # its own so that a 2-D x can be paired with a 1-D y (and vice versa).
    if x.ndim == 1:
        x = x.reshape(-1, 1)
    if y.ndim == 1:
        y = y.reshape(-1, 1)

    n = x.shape[0]
    m = y.shape[0]
    if n < 2 or m < 2:
        raise ValueError('mmd_u needs at least two samples in both x and y')

    if gamma is None:
        gamma = kernelwidthPair(x, y)

    kernel = RBF(length_scale=gamma)
    Kxx = kernel(x)
    Kyy = kernel(y)
    Kxy = kernel(x, y)

    # Within-sample sums drop the diagonal k(x_i, x_i); excluding these
    # self-similarity terms is what makes the estimator unbiased.
    A = np.sum(Kxx) - np.trace(Kxx)
    B = np.sum(Kyy) - np.trace(Kyy)
    C = np.sum(Kxy)

    return A / (n * (n - 1)) + B / (m * (m - 1)) - 2 * C / (n * m)


def mmd_b(x, y, gamma=None):
    '''Biased estimator of the squared MMD between samples x and y.

    With k an RBF kernel of length scale ``gamma``, computes

          1/n^2   * sum_{i, j} k(x_i, x_j)
        + 1/m^2   * sum_{i, j} k(y_i, y_j)
        - 2/(n m) * sum_{i, j} k(x_i, y_j)

    Unlike ``mmd_u``, the within-sample sums include the diagonal terms.

    Args:
        x: array of shape (n,) or (n, d), samples from the first distribution.
        y: array of shape (m,) or (m, d), samples from the second distribution.
        gamma: scalar RBF length scale; if None, the median heuristic
            ``kernelwidthPair(x, y)`` is used.

    Returns:
        The estimate as a float.

    Raises:
        ValueError: if x or y is empty.
    '''
    # Treat 1-D inputs as samples of a single feature. Check each array on
    # its own so that a 2-D x can be paired with a 1-D y (and vice versa).
    if x.ndim == 1:
        x = x.reshape(-1, 1)
    if y.ndim == 1:
        y = y.reshape(-1, 1)

    n = x.shape[0]
    m = y.shape[0]
    if n < 1 or m < 1:
        raise ValueError('mmd_b needs at least one sample in both x and y')

    if gamma is None:
        gamma = kernelwidthPair(x, y)

    kernel = RBF(length_scale=gamma)
    A = np.sum(kernel(x))
    B = np.sum(kernel(y))
    C = np.sum(kernel(x, y))

    return A / n**2 + B / m**2 - 2 * C / (n * m)

def kernelwidthPair(x1, x2):
    '''Median heuristic for the kernel width (see Gretton et al., 2012).

    Returns the median of all pairwise Euclidean distances ||a - b||, where
    a ranges over the rows of x1 and b over the rows of x2. The estimators
    in this module use the result directly as the RBF ``length_scale``.

    Args:
        x1: array-like of shape (n,) or (n, d).
        x2: array-like of shape (m,) or (m, d).

    Returns:
        The median cross-sample distance. This may be 0 when many pairs of
        points coincide, which is a degenerate length scale.
    '''
    # cdist requires 2-D (samples, features) input; treat 1-D arrays as
    # samples of a single feature instead of failing.
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    if x1.ndim == 1:
        x1 = x1.reshape(-1, 1)
    if x2.ndim == 1:
        x2 = x2.reshape(-1, 1)

    D = distance.cdist(x1, x2, 'euclidean')
    return np.median(D)



def witness(x, y, z, epsilon, gamma=None):
    '''Epsilon-weighted witness function evaluated at the locations z.

    For each location z_k let
        mx(z_k) = mean_i k(z_k, x_i),    my(z_k) = mean_j k(z_k, y_j),
    where k is an RBF kernel with length scale ``gamma``. Returns
        4 * epsilon * (1 - epsilon) * mx - 4 * epsilon * (1 + epsilon) * my

    Args:
        x: array (n,) or (n, d), observed samples from the first distribution.
        y: array (m,) or (m, d), observed samples from the second distribution.
        z: array (p,) or (p, d), locations where the witness is evaluated.
        epsilon: scalar weight used in the two coefficients above.
        gamma: scalar RBF length scale; if None, the median heuristic
            ``kernelwidthPair(x, y)`` is used.

    Returns:
        Array of shape (p,), one value per location in z.
    '''
    # Reshape BEFORE estimating gamma so the median heuristic always sees
    # 2-D (samples, features) arrays, as scipy's cdist requires.
    if z.ndim == 1:
        z = z.reshape(-1, 1)
    if x.ndim == 1:
        x = x.reshape(-1, 1)
    if y.ndim == 1:
        y = y.reshape(-1, 1)

    if gamma is None:
        gamma = kernelwidthPair(x, y)

    kernel = RBF(length_scale=gamma)
    mean_kxz = np.mean(kernel(z, x), axis=1)
    mean_kyz = np.mean(kernel(z, y), axis=1)

    # NOTE(review): the y coefficient uses (1 + epsilon) while the x
    # coefficient uses (1 - epsilon); confirm this asymmetry is intended.
    return (4 * epsilon * (1 - epsilon) * mean_kxz
            - 4 * epsilon * (1 + epsilon) * mean_kyz)



def witness_simple(x, y, z, gamma=None):
    '''Empirical MMD witness function evaluated at the locations z.

    For each location z_k returns
        mean_i k(z_k, x_i) - mean_j k(z_k, y_j)
    where k is an RBF kernel with length scale ``gamma``.

    Args:
        x: array (n,) or (n, d), observed samples from the first distribution.
        y: array (m,) or (m, d), observed samples from the second distribution.
        z: array (p,) or (p, d), locations where the witness is evaluated.
        gamma: scalar RBF length scale; if None, the median heuristic
            ``kernelwidthPair(x, y)`` is used.

    Returns:
        Array of shape (p,), one value per location in z.
    '''
    # Reshape BEFORE estimating gamma so the median heuristic always sees
    # 2-D (samples, features) arrays, as scipy's cdist requires.
    if z.ndim == 1:
        z = z.reshape(-1, 1)
    if x.ndim == 1:
        x = x.reshape(-1, 1)
    if y.ndim == 1:
        y = y.reshape(-1, 1)

    if gamma is None:
        gamma = kernelwidthPair(x, y)

    kernel = RBF(length_scale=gamma)
    return np.mean(kernel(z, x), axis=1) - np.mean(kernel(z, y), axis=1)